# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
######################## write mindrecord example ########################
Write mindrecord by dataloader dictionary:
python writer.py --mindrecord_script /YourScriptPath ...
"""
import argparse
import os
import time
from importlib import import_module
from multiprocessing import Pool
import json

from mindspore.mindrecord import FileWriter
from .graph_map_schema import GraphMapSchema


def exec_task(task_id, parallel_writer=True):
    """
    Transform and write every record produced for one task.

    Records are pulled from ``mindrecord_dict_data(task_id)``, converted with
    the module-level ``graph_map_schema`` and handed to the module-level
    ``writer`` in chunks of at most 512 records. All three names are module
    globals that must be set before this function is called.

    Args:
        task_id (int): Task id forwarded to ``mindrecord_dict_data``.
        parallel_writer (bool): Forwarded to ``writer.write_raw_data``.
    """
    print("exec task {}, parallel: {} ...".format(task_id, parallel_writer))
    records = mindrecord_dict_data(task_id)
    chunk_limit = 512
    done = 0

    def flush(chunk):
        # Write one chunk and report the running total of transformed records.
        writer.write_raw_data(chunk, parallel_writer=parallel_writer)
        print("transformed {} record...".format(done))

    while True:
        chunk = []
        try:
            while len(chunk) < chunk_limit:
                record = next(records)
                # A record carrying a 'dst_id' key is handled as an edge,
                # anything else as a node.
                convert = (graph_map_schema.transform_edge if 'dst_id' in record
                           else graph_map_schema.transform_node)
                chunk.append(convert(record))
                done += 1
            flush(chunk)
        except StopIteration:
            # Source exhausted: write the trailing partial chunk, if any.
            if chunk:
                flush(chunk)
            break


def read_args(argv=None):
    """
    Parse the command line options of the mindrecord writer.

    Args:
        argv (list[str], optional): Argument strings to parse. When None
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace, the parsed options.
    """
    parser = argparse.ArgumentParser(description='Mind record writer')
    parser.add_argument('--mindrecord_script', type=str, default="template",
                        help='path where script is saved')

    parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord/xyz",
                        help='written file name prefix')

    parser.add_argument('--mindrecord_partitions', type=int, default=1,
                        help='number of written files')

    parser.add_argument('--mindrecord_header_size_by_bit', type=int, default=24,
                        help='mindrecord file header size')

    parser.add_argument('--mindrecord_page_size_by_bit', type=int, default=25,
                        help='mindrecord file page size')

    parser.add_argument('--mindrecord_workers', type=int, default=8,
                        help='number of parallel workers')

    parser.add_argument('--num_node_tasks', type=int, default=1,
                        help='number of node tasks')

    # The help text previously repeated "number of node tasks" here.
    parser.add_argument('--num_edge_tasks', type=int, default=1,
                        help='number of edge tasks')

    parser.add_argument('--graph_api_args', type=str, default="/tmp/nodes.csv:/tmp/edges.csv",
                        help='nodes and edges dataloader file, csv format with header.')

    return parser.parse_args(argv)


def init_writer(mr_schema):
    """
    Build the mindrecord FileWriter and prepare it for writing.

    The writer is created from ``args.mindrecord_file`` and
    ``args.mindrecord_partitions`` (``args`` is a module global). A header or
    page size is only pushed to the writer when the corresponding bit count
    differs from 24 / 25 respectively; otherwise the setter is not called.

    Args:
        mr_schema: Schema registered under the name "mindrecord_graph_schema".

    Returns:
        FileWriter, with schema added and header opened.
    """
    print("Init writer  ...")
    file_writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions)

    # (requested bit count, bit count that needs no override, setter);
    # header first, then page.
    size_overrides = (
        (args.mindrecord_header_size_by_bit, 24, file_writer.set_header_size),
        (args.mindrecord_page_size_by_bit, 25, file_writer.set_page_size),
    )
    for bits, untouched_bits, apply_size in size_overrides:
        if bits != untouched_bits:
            # sizes are given as a power-of-two exponent
            apply_size(1 << bits)

    # register the schema, then open the output and set the header
    file_writer.add_schema(mr_schema, "mindrecord_graph_schema")
    file_writer.open_and_set_header()

    return file_writer


def run_parallel_workers(num_tasks):
    """
    Run ``exec_task`` for every task id in ``range(num_tasks)``.

    On Windows (``os.name == 'nt'``), or when there is only one task, tasks
    run sequentially in this process with ``parallel_writer=False``.
    Otherwise they are mapped over a process pool of at most
    ``args.mindrecord_workers`` workers, never more than ``num_tasks``.

    Args:
        num_tasks (int): Number of tasks to run. A value of zero or less runs
            nothing. Previously task 0 was still executed in that case on
            non-Windows platforms.
    """
    if num_tasks <= 0:
        # Nothing requested: do not fall through to running task 0.
        return

    task_list = list(range(num_tasks))

    if os.name == 'nt' or num_tasks == 1:
        # NOTE(review): the Windows path presumably avoids the pool because
        # worker processes there would not see the module-level writer
        # state set up in __main__ - confirm.
        for task_id in task_list:
            exec_task(task_id, False)
        return

    # never start more workers than there are tasks
    num_workers = min(args.mindrecord_workers, num_tasks)
    with Pool(num_workers) as pool:
        pool.map(exec_task, task_list)


def write_meta(graph_meta_file):
    """
    Write the graph meta information as JSON.

    The output path is ``args.mindrecord_file`` (module global) with the
    suffix ``_graph_meta.json``.

    Args:
        graph_meta_file: Object whose ``json`` attribute is dumped to the file.
    """
    target_path = "{}_graph_meta.json".format(args.mindrecord_file)
    with open(target_path, "w") as out_stream:
        json.dump(graph_meta_file.json, out_stream)


if __name__ == "__main__":
    # NOTE: args, graph_map_schema, writer and mindrecord_dict_data are read
    # as module globals by the functions above, so their names must not change.
    args = read_args()
    print(args)

    start_time = time.time()

    # pass mr_api arguments to the user script through the environment
    os.environ['graph_api_args'] = args.graph_api_args

    try:
        mr_api = import_module(args.mindrecord_script)
    except ModuleNotFoundError as err:
        # chain explicitly so the original import failure stays visible
        raise RuntimeError(f"Unknown module path: {args.mindrecord_script}") from err

    # init graph schema
    graph_map_schema = GraphMapSchema()

    node_feature_names, node_feature_types, node_feature_shapes = mr_api.node_profile
    graph_map_schema.add_node_features_schema(node_feature_names, node_feature_types, node_feature_shapes)

    edge_feature_names, edge_feature_types, edge_feature_shapes = mr_api.edge_profile
    graph_map_schema.add_edge_features_schema(edge_feature_names, edge_feature_types, edge_feature_shapes)

    graph_meta_variants = mr_api.variants()
    graph_map_schema.add_meta_variants(graph_meta_variants)

    graph_schema = graph_map_schema.schema
    graph_meta = graph_map_schema.meta

    # init writer
    writer = init_writer(graph_schema)

    # write nodes: exec_task pulls records from mindrecord_dict_data
    mindrecord_dict_data = mr_api.yield_nodes
    run_parallel_workers(args.num_node_tasks)

    # write edges: rebind the record source before the second run
    mindrecord_dict_data = mr_api.yield_edges
    run_parallel_workers(args.num_edge_tasks)

    # write graph meta information
    write_meta(graph_meta)

    # writer wrap up
    writer.commit()

    end_time = time.time()
    print("--------------------------------------------")
    print("END. Total time: {}".format(end_time - start_time))
    print("--------------------------------------------")
